{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/lokwq/TextBrewer/blob/add_note_examples/sst2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UMExDS48VN58"
   },
   "source": [
    "This notebook shows how to fine-tune a model on the SST-2 dataset and how to distill the model with TextBrewer.\n",
    "\n",
    "Detailed docs can be found here:\n",
    "https://github.com/airaria/TextBrewer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "oVTjuvH0rPsT"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
    "gpu_available = torch.cuda.is_available()\n",
    "device = 'cuda' if gpu_available else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yIgAk4WVrKtC"
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic so packages are installed into the running kernel's environment.\n",
    "# transformers is pinned to the version this notebook was written against (4.8.2).\n",
    "%pip install transformers==4.8.2\n",
    "%pip install datasets\n",
    "%pip install textbrewer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "qqu-aNtc3QgP"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from transformers import BertForSequenceClassification, BertTokenizer,BertConfig\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from datasets import load_dataset,load_metric"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h5ww8ad58D8v"
   },
   "source": [
    "### Prepare dataset to train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9-8wYOHG4WVq"
   },
   "outputs": [],
   "source": [
    "# Fetch the three GLUE SST-2 splits by iterating over the split names.\n",
    "train_dataset, val_dataset, test_dataset = (\n",
    "    load_dataset('glue', 'sst2', split=split_name)\n",
    "    for split_name in ('train', 'validation', 'test')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iQSki-hv5Imc"
   },
   "outputs": [],
   "source": [
    "def _copy_label_column(examples):\n",
    "    \"\"\"Return the batch's 'label' values under the key 'labels'.\"\"\"\n",
    "    return {'labels': examples['label']}\n",
    "\n",
    "# Add the 'labels' column first, then drop the original 'label' column from each split.\n",
    "train_dataset = train_dataset.map(_copy_label_column, batched=True).remove_columns(['label'])\n",
    "val_dataset = val_dataset.map(_copy_label_column, batched=True).remove_columns(['label'])\n",
    "test_dataset = test_dataset.map(_copy_label_column, batched=True).remove_columns(['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "whL22dsx5QU5"
   },
   "outputs": [],
   "source": [
    "# The classifier and its tokenizer are loaded from the same pretrained checkpoint name.\n",
    "PRETRAINED_NAME = 'bert-base-uncased'\n",
    "tokenizer = BertTokenizer.from_pretrained(PRETRAINED_NAME)\n",
    "model = BertForSequenceClassification.from_pretrained(PRETRAINED_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eH4rBumG5i6S"
   },
   "outputs": [],
   "source": [
    "MAX_LENGTH = 128\n",
    "\n",
    "def _tokenize_batch(batch):\n",
    "    \"\"\"Tokenize the 'sentence' field, truncating and padding every example to MAX_LENGTH.\"\"\"\n",
    "    return tokenizer(\n",
    "        batch['sentence'],\n",
    "        truncation=True,\n",
    "        padding='max_length',\n",
    "        max_length=MAX_LENGTH,\n",
    "    )\n",
    "\n",
    "train_dataset = train_dataset.map(_tokenize_batch, batched=True)\n",
    "val_dataset = val_dataset.map(_tokenize_batch, batched=True)\n",
    "test_dataset = test_dataset.map(_tokenize_batch, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "Nv-gsKvG5ylO"
   },
   "outputs": [],
   "source": [
    "# Expose only the model inputs and the labels, as torch tensors, on every split.\n",
    "TENSOR_COLUMNS = ['input_ids', 'token_type_ids', 'attention_mask', 'labels']\n",
    "for split_dataset in (train_dataset, val_dataset, test_dataset):\n",
    "    split_dataset.set_format(type='torch', columns=list(TENSOR_COLUMNS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "6jwP2aHv6EU6"
   },
   "outputs": [],
   "source": [
    "def compute_metrics(pred):\n",
    "    \"\"\"Compute accuracy and macro-averaged precision, recall and F1.\n",
    "\n",
    "    `pred` must provide `label_ids` and `predictions`; the predicted class is\n",
    "    the argmax over the last axis of `predictions`.\n",
    "    \"\"\"\n",
    "    y_true = pred.label_ids\n",
    "    y_pred = pred.predictions.argmax(-1)\n",
    "    macro_precision, macro_recall, macro_f1, _support = precision_recall_fscore_support(\n",
    "        y_true, y_pred, average='macro'\n",
    "    )\n",
    "    return {\n",
    "        'accuracy': accuracy_score(y_true, y_pred),\n",
    "        'f1': macro_f1,\n",
    "        'precision': macro_precision,\n",
    "        'recall': macro_recall,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KonAbPBj6NCK"
   },
   "outputs": [],
   "source": [
    "# Fine-tune the teacher with the Hugging Face Trainer.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='/content/drive/MyDrive/results',   # output directory\n",
    "    learning_rate=1e-4,\n",
    "    num_train_epochs=5,\n",
    "    per_device_train_batch_size=32,                # batch size per device during training\n",
    "    per_device_eval_batch_size=32,                 # batch size for evaluation\n",
    "    logging_dir='/content/drive/MyDrive/logs',\n",
    "    logging_steps=100,\n",
    "    do_train=True,\n",
    "    do_eval=True,\n",
    "    no_cuda=False,\n",
    "    load_best_model_at_end=True,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    # load_best_model_at_end needs a checkpoint for every evaluation,\n",
    "    # so checkpoints are saved on the same schedule as evaluation.\n",
    "    save_strategy=\"epoch\"\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    compute_metrics=compute_metrics\n",
    ")\n",
    "\n",
    "train_out = trainer.train()\n",
    "\n",
    "# After training, the logs and checkpoints are in your own Drive; change the paths in training_args to store them elsewhere."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1H8Dod2y6R8c"
   },
   "outputs": [],
   "source": [
    "# Persist only the fine-tuned weights (the state dict) to Google Drive.\n",
    "TEACHER_WEIGHTS_PATH = '/content/drive/MyDrive/sst2_teacher_model.pt'\n",
    "torch.save(model.state_dict(), TEACHER_WEIGHTS_PATH)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Gov66CaFNAgg"
   },
   "source": [
    "### Start distillation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IA-gwQKNB8fs"
   },
   "outputs": [],
   "source": [
    "# Training batches for distillation; shuffle so every epoch sees the examples in a different order.\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YD8qPZmUiTKH"
   },
   "outputs": [],
   "source": [
    "import textbrewer\n",
    "from textbrewer import GeneralDistiller\n",
    "from textbrewer import TrainingConfig, DistillationConfig\n",
    "from transformers import BertForSequenceClassification, BertConfig, AdamW,BertTokenizer\n",
    "from transformers import get_linear_schedule_with_warmup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4emuX8UK8Mup"
   },
   "source": [
    "Initialize the student model from a BertConfig and prepare the teacher model.\n",
    "\n",
    "bert_config_L3.json describes a 3-layer BERT.\n",
    "\n",
    "bert_config.json describes a standard 12-layer BERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CKLaqSPCiX1a"
   },
   "outputs": [],
   "source": [
    "# Location of the TextBrewer example configs in Google Drive; adjust it to where the repository was unpacked.\n",
    "CONFIG_DIR = '/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config'\n",
    "\n",
    "# Student: a freshly initialised model built from bert_config_L3.json.\n",
    "bert_config_T3 = BertConfig.from_json_file(CONFIG_DIR + '/bert_config_L3.json')\n",
    "bert_config_T3.output_hidden_states = True  # the adaptor below returns the hidden states\n",
    "# The data was tokenized with the bert-base-uncased tokenizer, while these configs live in a\n",
    "# *cased* config directory. Size the embedding table from the tokenizer so every input id is valid\n",
    "# (a no-op if the config already agrees).\n",
    "bert_config_T3.vocab_size = tokenizer.vocab_size\n",
    "\n",
    "student_model = BertForSequenceClassification(bert_config_T3)\n",
    "student_model.to(device=device)\n",
    "\n",
    "# Teacher: rebuilt from bert_config.json and restored from the state dict saved after fine-tuning.\n",
    "bert_config = BertConfig.from_json_file(CONFIG_DIR + '/bert_config.json')\n",
    "bert_config.output_hidden_states = True\n",
    "bert_config.vocab_size = tokenizer.vocab_size  # must match the embeddings stored in the checkpoint\n",
    "teacher_model = BertForSequenceClassification(bert_config)\n",
    "# map_location lets weights that were saved on a GPU load on a CPU-only runtime as well.\n",
    "teacher_model.load_state_dict(torch.load('/content/drive/MyDrive/sst2_teacher_model.pt', map_location=device))\n",
    "teacher_model.to(device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W6SuVnpa8RAm"
   },
   "source": [
    "The cell below distills the teacher model into the student model prepared above.\n",
    "\n",
    "After the cell finishes running, the distilled model will be in 'saved_model' in the Colab file list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CIxaegSUikGX"
   },
   "outputs": [],
   "source": [
    "num_epochs = 20\n",
    "num_training_steps = len(train_dataloader) * num_epochs\n",
    "\n",
    "# Optimizer over the student's parameters only.\n",
    "optimizer = AdamW(student_model.parameters(), lr=1e-4)\n",
    "\n",
    "# Learning-rate scheduler: pass the factory and its arguments (everything except 'optimizer') to the distiller.\n",
    "scheduler_class = get_linear_schedule_with_warmup\n",
    "scheduler_args = {'num_warmup_steps': int(0.1 * num_training_steps), 'num_training_steps': num_training_steps}\n",
    "\n",
    "\n",
    "def simple_adaptor(batch, model_outputs):\n",
    "    \"\"\"Expose the model's logits and hidden states under the keys 'logits' and 'hidden'.\n",
    "\n",
    "    `batch` is part of the adaptor signature but is not used here.\n",
    "    \"\"\"\n",
    "    return {'logits': model_outputs.logits, 'hidden': model_outputs.hidden_states}\n",
    "\n",
    "\n",
    "# Match hidden states of teacher layers 0 and 8 to student layers 0 and 2 with an MSE loss.\n",
    "distill_config = DistillationConfig(\n",
    "    intermediate_matches=[\n",
    "        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},\n",
    "        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}])\n",
    "# Run distillation on the same device the models were moved to, rather than TrainingConfig's default device.\n",
    "train_config = TrainingConfig(device=device)\n",
    "\n",
    "distiller = GeneralDistiller(\n",
    "    train_config=train_config, distill_config=distill_config,\n",
    "    model_T=teacher_model, model_S=student_model,\n",
    "    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)\n",
    "\n",
    "\n",
    "with distiller:\n",
    "    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args=scheduler_args, callback=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "F8acpGydEgLf",
    "outputId": "79a0c44f-7f03-4d6d-b09b-84a858fa1360"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 65,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild the student architecture and load the distilled weights into it.\n",
    "# gs4210.pkl is the distilled model weights file; map_location keeps the load working on a CPU-only runtime.\n",
    "test_model = BertForSequenceClassification(bert_config_T3)\n",
    "test_model.load_state_dict(torch.load('/content/drive/MyDrive/gs4210.pkl', map_location='cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_5QYCFnAMkpE"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Batches of the validation split, in dataset order.\n",
    "EVAL_BATCH_SIZE = 8\n",
    "eval_dataloader = DataLoader(dataset=val_dataset, batch_size=EVAL_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u0Pb4CeJdLCk"
   },
   "outputs": [],
   "source": [
    "# Score the distilled student on the validation split.\n",
    "metric = load_metric(\"accuracy\")\n",
    "test_model.to(device)\n",
    "test_model.eval()\n",
    "for batch in eval_dataloader:\n",
    "    # Move every tensor of the batch to the device the model lives on.\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        outputs = test_model(**batch)\n",
    "\n",
    "    predictions = torch.argmax(outputs.logits, dim=-1)\n",
    "    # Hand CPU tensors to the metric so it does not depend on where the model ran.\n",
    "    metric.add_batch(predictions=predictions.cpu(), references=batch[\"labels\"].cpu())\n",
    "\n",
    "metric.compute()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyPSRVIK8638b2CgsGZ/nsAR",
   "collapsed_sections": [],
   "include_colab_link": true,
   "mount_file_id": "1LgVQBkBlDbyTgriuZ88TDXPA6MSUKN3W",
   "name": "sst2_bert_fin.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
